/*
Copyright 2016 Fixstars Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include <iostream>
#include <iomanip>
#include <string>
#include <chrono>
#include <boost/bind.hpp>

#include <cuda_runtime.h>
#include <opencv4/opencv2/core/core.hpp>
#include <opencv4/opencv2/highgui/highgui.hpp>
#include <opencv4/opencv2/imgproc/imgproc.hpp>
#include <opencv4/opencv2/core/version.hpp>
#include <cv_bridge/cv_bridge.h>

#include <libsgm.h>

#include <ros/ros.h>
#include <sensor_msgs/Image.h>
#include <sensor_msgs/image_encodings.h>
#include <ros/spinner.h>
#include <sensor_msgs/CameraInfo.h>
#include <message_filters/subscriber.h>
#include <message_filters/synchronizer.h>
#include <message_filters/time_synchronizer.h>
#include <message_filters/sync_policies/approximate_time.h>

// Publisher for the colour-mapped disparity image. It is assigned in main()
// (topic "/depth/SGM") and used by callBck() to publish each result.
ros::Publisher depth_pub;
// Terminates the process with EXIT_FAILURE after writing `msg` to std::cerr
// when `expr` evaluates to false.
//
// `msg` is spliced into a `std::cerr << msg << std::endl` chain, so it may be
// a single value or several operands joined with `<<`.
//
// The body is wrapped in do { ... } while (0) so that the macro expands to
// exactly one statement: it then needs its trailing semicolon and is safe
// inside an unbraced if/else.
#define ASSERT_MSG(expr, msg) \
    do { \
        if (!(expr)) { \
            std::cerr << msg << std::endl; \
            std::exit(EXIT_FAILURE); \
        } \
    } while (0)

struct device_buffer
{
    device_buffer() : data(nullptr) {}
    device_buffer(size_t count) { allocate(count); }
    void allocate(size_t count) { cudaMalloc(&data, count); }
    ~device_buffer() { cudaFree(data); }
    void* data;
};

/**
 * printf-style formatting into a std::string of exactly the required length.
 *
 * @param fmt   printf format string
 * @param args  values consumed by the conversions in `fmt`
 * @return the formatted text; an empty string if std::snprintf reports an
 *         error or the formatted text is empty
 */
template <class... Args>
static std::string format_string(const char* fmt, Args... args)
{
    // First pass writes nothing and only measures the formatted length
    // (terminator excluded), so long outputs are never truncated.
    const int length = std::snprintf(nullptr, 0, fmt, args...);
    if (length <= 0)
        return std::string();

    // One extra character gives snprintf room for its terminating '\0',
    // which is trimmed off again afterwards.
    std::string text(static_cast<std::size_t>(length) + 1, '\0');
    std::snprintf(&text[0], text.size(), fmt, args...);
    text.resize(static_cast<std::size_t>(length));
    return text;
}

/**
 * Callback for a synchronised image pair: computes a disparity map with
 * libSGM on the GPU and publishes a colour-mapped rendering on `depth_pub`.
 *
 * Historical note (translated from the original comment): earlier errors
 * around image conversion and this callback were caused by the
 * image_transport package not having been added.
 *
 * @param left_image_msg   image used as the left view of the pair
 * @param right_image_msg  image used as the right view of the pair
 * @param depth_mat        receives the disparity rescaled to 8 bits
 *                         (the single-channel image before colour mapping)
 *
 * Failure policy:
 *  - a cv_bridge conversion error is logged and the pair is dropped;
 *  - an empty or mismatched pair, or a failing CUDA call, terminates the
 *    process through ASSERT_MSG.
 */
void callBck(const sensor_msgs::Image::ConstPtr& left_image_msg, const sensor_msgs::Image::ConstPtr& right_image_msg, cv::Mat& depth_mat)
{
    // libSGM needs single-channel input, so both messages are converted to
    // MONO8; a three-channel encoding such as BGR8 cannot be used here.
    cv_bridge::CvImagePtr left_image_ptr, right_image_ptr;
    try {
        left_image_ptr = cv_bridge::toCvCopy(*left_image_msg, sensor_msgs::image_encodings::MONO8);
        right_image_ptr = cv_bridge::toCvCopy(*right_image_msg, sensor_msgs::image_encodings::MONO8);
    } catch (const cv_bridge::Exception& e) {
        ROS_ERROR("cv_bridge exception: %s", e.what());
        return;
    }

    // MONO8 already yields 8-bit single-channel matrices, so no further
    // depth conversion is needed.
    cv::Mat left_image = left_image_ptr->image;
    cv::Mat right_image = right_image_ptr->image;

    const int disp_size = 128;
    static_assert(disp_size == 64 || disp_size == 128 || disp_size == 256, "disparity size must be 64, 128 or 256.");

    // Check that the inputs satisfy libSGM's format requirements.
    ASSERT_MSG(!left_image.empty() && !right_image.empty(), "received an empty image.");
    ASSERT_MSG(left_image.size() == right_image.size() && left_image.type() == right_image.type(), "input images must be same size and type.");
    ASSERT_MSG(left_image.type() == CV_8U || left_image.type() == CV_16U, "input image format must be CV_8U or CV_16U.");

    // The uploads below copy each image as one flat block, which is only
    // valid when the rows are stored back to back.
    if (!left_image.isContinuous()) left_image = left_image.clone();
    if (!right_image.isContinuous()) right_image = right_image.clone();

    // Image dimensions and buffer sizes.
    const int width = left_image.cols;
    const int height = left_image.rows;

    const int input_depth = left_image.type() == CV_8U ? 8 : 16;
    const int output_depth = disp_size < 256 ? 8 : 16;
    const size_t pixel_count = static_cast<size_t>(width) * static_cast<size_t>(height);
    const size_t input_bytes = pixel_count * input_depth / 8;
    const size_t output_bytes = pixel_count * output_depth / 8;

    // StereoSGM(image width, image height, maximum disparity,
    //           input bit depth, output disparity bit depth, in/out memory mode).
    // NOTE(review): the matcher and the three device buffers are re-created
    // for every image pair; hoisting them out of the callback would avoid a
    // per-frame cudaMalloc/cudaFree cycle.
    sgm::StereoSGM matcher(width, height, disp_size, input_depth, output_depth, sgm::EXECUTE_INOUT_CUDA2CUDA);

    const int invalid_disp = output_depth == 8
            ? static_cast< uint8_t>(matcher.get_invalid_disparity())
            : static_cast<uint16_t>(matcher.get_invalid_disparity());

    cv::Mat disparity(height, width, output_depth == 8 ? CV_8U : CV_16U);
    cv::Mat disparity_8u, disparity_color;

    // Device-side buffers for the two inputs and the disparity output.
    device_buffer d_I1(input_bytes), d_I2(input_bytes), d_disparity(output_bytes);

    // Upload both images from host to device.
    cudaError_t cuda_status = cudaMemcpy(d_I1.data, left_image.data, input_bytes, cudaMemcpyHostToDevice);
    ASSERT_MSG(cuda_status == cudaSuccess, "cudaMemcpy (left image, host to device) failed: " << cudaGetErrorString(cuda_status));
    cuda_status = cudaMemcpy(d_I2.data, right_image.data, input_bytes, cudaMemcpyHostToDevice);
    ASSERT_MSG(cuda_status == cudaSuccess, "cudaMemcpy (right image, host to device) failed: " << cudaGetErrorString(cuda_status));

    // Run the matcher on the GPU; the synchronize call makes the measured
    // interval cover the whole execution and surfaces any execution error.
    const auto t1 = std::chrono::steady_clock::now();
    matcher.execute(d_I1.data, d_I2.data, d_disparity.data);
    cuda_status = cudaDeviceSynchronize();
    ASSERT_MSG(cuda_status == cudaSuccess, "sgm execution failed: " << cudaGetErrorString(cuda_status));
    const auto t2 = std::chrono::steady_clock::now();

    const auto duration = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
    const double fps = duration > 0 ? 1e6 / duration : 0.0;

    // Download the disparity map from device to host.
    cuda_status = cudaMemcpy(disparity.data, d_disparity.data, output_bytes, cudaMemcpyDeviceToHost);
    ASSERT_MSG(cuda_status == cudaSuccess, "cudaMemcpy (disparity, device to host) failed: " << cudaGetErrorString(cuda_status));

    // Rescale to the full 8-bit range, hand that image to the caller, then
    // build the colour rendering with invalid pixels painted black.
    disparity.convertTo(disparity_8u, CV_8U, 255. / disp_size);
    depth_mat = disparity_8u;
    cv::applyColorMap(disparity_8u, disparity_color, cv::COLORMAP_JET);
    disparity_color.setTo(cv::Scalar(0, 0, 0), disparity == invalid_disp);
    cv::putText(disparity_color, format_string("sgm execution time: %4.1f[msec] %4.1f[FPS]", 1e-3 * duration, fps),
                cv::Point(50, 50), cv::FONT_HERSHEY_DUPLEX, 0.75, cv::Scalar(255, 255, 255));
    std::cout << disparity_color.size() << std::endl;

    // The published image is the three-channel colour rendering, not the
    // single-channel CV_8U disparity. The left image's header is reused so
    // subscribers keep the original timestamp and frame id.
    sensor_msgs::ImagePtr msg = cv_bridge::CvImage(left_image_msg->header, sensor_msgs::image_encodings::BGR8, disparity_color).toImageMsg();
    depth_pub.publish(msg);
}

/**
 * Node entry point.
 *
 * Initialises ROS as "stereo_msg_ros", advertises the disparity image on
 * "/depth/SGM", subscribes to the two gray image topics, pairs them with an
 * approximate-time synchroniser and hands each pair to callBck().
 *
 * @return -1 if ROS is not in a usable state after initialisation,
 *         0 once ros::spin() returns
 */
int main(int argc, char* argv[])
{
    ros::init(argc, argv, "stereo_msg_ros");
    // ros::ok must be *called*; the bare function name is never null.
    if (!ros::ok()) return -1;

    ros::NodeHandle nh;
    depth_pub = nh.advertise<sensor_msgs::Image>("/depth/SGM", 1);

    // Receives the latest 8-bit disparity image written by callBck().
    cv::Mat depth_mat;

    message_filters::Subscriber<sensor_msgs::Image> left_image_sub(nh, "/zed/zed_node/left_raw/image_raw_gray", 1);
    message_filters::Subscriber<sensor_msgs::Image> right_image_sub(nh, "/zed/zed_node/right_raw/image_raw_gray", 1);

    typedef message_filters::sync_policies::ApproximateTime<sensor_msgs::Image, sensor_msgs::Image> synPolicy;
    message_filters::Synchronizer<synPolicy> sync(synPolicy(10), left_image_sub, right_image_sub);

    // boost::bind stores its bound arguments by value; boost::ref makes the
    // callback write into this function's depth_mat rather than a hidden copy.
    sync.registerCallback(boost::bind(&callBck, _1, _2, boost::ref(depth_mat)));

    ros::spin();
    return 0;
}
